// Copyright 2018 The Cockroach Authors.
// Copyright (c) 2022-present, Shanghai Yunxi Technology Co, Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This software (KWDB) is licensed under Mulan PSL v2.
// You can use this software according to the terms and conditions of the Mulan PSL v2.
// You may obtain a copy of Mulan PSL v2 at:
//          http://license.coscl.org.cn/MulanPSL2
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
// EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
// MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
// See the Mulan PSL v2 for more details.

package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"go/format"
	"io"
	"os"

	"gitee.com/kwbasedb/kwbase/pkg/sql/opt/optgen/lang"
)

// genFunc is the signature shared by the individual code generators: given the
// compiled Optgen definitions, it emits generated Go code into dst.
type genFunc func(expr *lang.CompiledExpr, dst io.Writer)

// Errors reported (via fatal) when the command line is malformed.
var (
	errInvalidArgCount     = errors.New("invalid number of arguments")
	errUnrecognizedCommand = errors.New("unrecognized command")
)

// out holds the value of the -out flag: the path of the file that receives
// the generated code. When it is empty, the code goes to standard error.
var out = flag.String("out", "", "output file name of generated code")

// useGoFmt decides whether the generated code is passed through go/format
// before it is written out.
const useGoFmt = true

// main is the langgen entry point. It expects a command ("exprs" or "ops")
// followed by one or more Optgen source files. It compiles the sources and
// hands the result to the generator selected by the command, which writes to
// the -out file (or standard error when -out is empty).
//
// The process exits with status 2 on a bad command line, on compile errors,
// or when generation fails.
func main() {
	flag.Usage = usage
	flag.Parse()

	args := flag.Args()
	if len(args) < 2 {
		flag.Usage()
		fatal(errInvalidArgCount)
	}

	// Validate the command before doing any (potentially slow) compilation.
	cmd, sources := args[0], args[1:]
	switch cmd {
	case "exprs", "ops":
	default:
		flag.Usage()
		fatal(errUnrecognizedCommand)
	}

	compiler := lang.NewCompiler(sources...)
	compiled := compiler.Compile()
	if compiled == nil {
		// Cap the number of printed errors so a badly broken source does not
		// flood the terminal, and say how many were left out.
		const maxErrors = 10
		errs := compiler.Errors()
		for i, err := range errs {
			if i >= maxErrors {
				fmt.Fprintf(os.Stderr, "... too many errors (%d more)\n", len(errs)-i)
				break
			}

			fmt.Fprintf(os.Stderr, "%v\n", err)
		}
		os.Exit(2)
	}

	var err error
	switch cmd {
	case "exprs":
		var gen exprsGen
		err = generate(compiled, *out, gen.generate)

	case "ops":
		err = generate(compiled, *out, generateOps)
	}

	if err != nil {
		fatal(err)
	}
}

// usage is a replacement usage function for the flags package.
// usage replaces the default flag.Usage output. It describes the langgen
// command line and its commands, then lists the registered flags via
// flag.PrintDefaults. Everything is written to standard error.
func usage() {
	const header = "LangGen generates the AST for the Optgen language.\n\n" +
		"LangGen uses the Optgen definition language to generate its own AST.\n" +
		"Usage:\n" +
		"\tlanggen [flags] command sources...\n\n" +
		"The commands are:\n\n" +
		"\texprs  generate expression definitions and functions\n" +
		"\tops    generate operator definitions and functions\n" +
		"\n" +
		"Flags:\n"

	stderr := os.Stderr
	fmt.Fprint(stderr, header)
	flag.PrintDefaults()
	fmt.Fprint(stderr, "\n")
}

// fatal prints err to standard error as "ERROR: <err>" and terminates the
// process with exit status 2. It does not return.
func fatal(err error) {
	fmt.Fprintln(os.Stderr, "ERROR:", err)
	os.Exit(2)
}

// generate produces a Go source file for package lang from the compiled
// Optgen definitions. The file starts with a "Code generated by langgen; DO
// NOT EDIT." header and a package clause. The rest of the body is produced by
// genFunc, and the result is run through go/format when useGoFmt is set.
//
// The result is written to the file named by out (created or truncated), or
// to standard error when out is empty.
//
// If formatting fails, the unformatted source is still written so the broken
// output can be inspected, and the formatting error is returned. Otherwise
// the first error from creating, writing, or closing the output is returned.
func generate(compiled *lang.CompiledExpr, out string, genFunc genFunc) error {
	var buf bytes.Buffer
	buf.WriteString("// Code generated by langgen; DO NOT EDIT.\n\n")
	buf.WriteString("package lang\n\n")

	genFunc(compiled, &buf)

	b := buf.Bytes()
	var err error
	if useGoFmt {
		formatted, fmtErr := format.Source(b)
		if fmtErr != nil {
			// Keep the unformatted source so it is written out for debugging.
			err = fmt.Errorf("code formatting failed with Go parse error\n%s:%w", out, fmtErr)
		} else {
			b = formatted
		}
	}

	var writer io.Writer = os.Stderr
	var file *os.File
	if out != "" {
		f, createErr := os.Create(out)
		if createErr != nil {
			return createErr
		}
		file = f
		writer = f
	}

	// Write the source even when formatting failed, but report only the
	// first error that occurred.
	_, writeErr := writer.Write(b)
	if err == nil {
		err = writeErr
	}

	if file != nil {
		// Close may report failures that Write did not, so its error is
		// not dropped.
		closeErr := file.Close()
		if err == nil {
			err = closeErr
		}
	}

	return err
}
